import pandas as pd
import numpy as np
import re
import statsmodels.formula.api as smf
from scipy import stats
import matplotlib.pyplot as plt

# Output and input locations. Both are placeholders and must end with a path
# separator because file names are appended by plain string concatenation.
outdir = 'Directory where you want the results to go\\'
datadir = 'Directory where you stored data.\\'

# Regression design matrix; countyFIPS is read as a string column.
regdat = pd.read_csv(
    datadir + 'TotalData.csv',
    low_memory=False,
    dtype={'countyFIPS': 'str'},
)

# Divide by 10 to create a GrowthFatRate consistent with what Heather used.
# Applies to every column whose name contains 'GrowthFatRate' (lags included).
growthCols = [col for col in regdat.columns if 'GrowthFatRate' in col]
regdat.loc[:, growthCols] = regdat.loc[:, growthCols] / 10

# County pairs: each row links a county (countyFIPS) to a neighbor (neighborFIPS).
pairdat = pd.read_csv(
    datadir + 'CountyHedonicNeighborsMile_100_NB_HD2.csv',
    low_memory=False,
    dtype={'countyFIPS': str, 'neighborFIPS': str},
)

# Sorted, de-duplicated FIPS codes appearing on each side of the pair file.
treatFIPS = sorted(set(pairdat['countyFIPS']))
controlFIPS = sorted(set(pairdat['neighborFIPS']))

# Row subsets of the design matrix: treated only, control only, and their union.
isTreatRow = regdat['countyFIPS'].isin(treatFIPS)
isControlRow = regdat['countyFIPS'].isin(controlFIPS)
regdatTreat = regdat.loc[isTreatRow, :]
regdatControl = regdat.loc[isControlRow, :]
regdatTC = regdat.loc[isControlRow | isTreatRow, :]


# Patsy formula for the fatality-growth regression. The pieces below are joined
# by implicit string concatenation with no separator, so the spacing around '+'
# is irregular at the line joins (e.g. 'Asian +Black'); it is kept exactly as is
# because later code pattern-matches on this string.
# NOTE(review): TotDeathsPerCap is listed twice in the formula.
regformula = (
    'GrowthFatRate ~ 1 + HHMF + HHSS + TEMPcoldMF + TEMPcoldSS + TEMP + PerCapIncome + Age65plus + Age85plus + Asian +'
    'Black + Hispanic + NativeAmer + Other + Diabetes + Obese + Smoke + HousingDensity + PopDensity + NurseTotPop + w + tMinusTrig +'
    'TotDeathsPerCap + GrowthFatRatelag1 + GrowthFatRatelag2+ GrowthFatRatelag3+'
    'GrowthFatRatelag4 + GrowthFatRatelag5 + GrowthFatRatelag6 + TotDeathsPerCap + dParksStart + dElectiveStart +'
    'dNursingHomeStart +  dBusOpenRevStart + dResidentMaskRecStart + dRisk2Closed + dNursingMustAcceptStart + dGatherMax10Start +'
    'dEmergencyStart + dGatherMax100Start + dGatherMax100plusStart + dStayHomeStart + dResidentMaskMandStart + dEmployeeMaskStart +'
    'dRisk4Closed + dRisk1Closed + dRisk3Closed + dbarCrestC + dbarCrestOpen + dGymsZeroStart + dSpasZeroStart'
)

#regdat['dRestaurantsBarsZeroStart'] = regdat[['dRestaurantsZeroStart','dBarsZeroStart']].min(axis='columns')
#regdat['wRestaurantsBarsZeroStart'] = regdat[['wRestaurantsZeroStart','wBarsZeroStart']].min(axis='columns')
#regdat['dBarsOnlyZeroStart'] = regdat['dBarsZeroStart']*(1-regdat['dRestaurantsZeroStart'])
#regdat['wBarsOnlyZeroStart'] = regdat['wBarsZeroStart']*(1-regdat['dRestaurantsZeroStart'])

#regformula = re.sub('BarsZeroStart','BarsOnlyZeroStart',regformula)
#regformula = re.sub('RestaurantsZeroStart','RestaurantsBarsZeroStart',regformula)


# Baseline OLS fit of the full formula on the treated counties only.
# The fitted result is not referenced again in the visible part of this script.
resultAll = smf.ols(formula=regformula, data=regdatTreat).fit()

# Maps each order's column name in the data (w-prefixed) to the human-readable
# label used for table rows, plot titles and output file names. Insertion order
# determines the row order of the result tables.
outorder = {
    'wStayHomeStart': 'Stay at Home',
    'wEmergencyStart': 'State of Emergency',
    'wNursingMustAcceptStart': 'Nursing Home Accept Pos.',
    'wNursingHomeStart': 'No Nursing Home Visit',
    'wEmployeeMaskStart': 'Employee Mask',
    'wResidentMaskRecStart': 'Masks Recommended',
    'wResidentMaskMandStart': 'Mandatory Masks',
    'wParksStart': 'Parks Closed',
    'wElectiveStart': 'No Elective Procedures',
    'wbarCrestC': 'Bars Closed, Restaurants Closed',
    'wbarCrestOpen': 'Bars Closed, Restaurants Open',
    'wGymsZeroStart': 'Gyms Closed',
    'wSpasZeroStart': 'Spas Closed',
    'wGatherMax10Start': 'Gatherings limited to 10',
    'wGatherMax100Start': 'No gatherings over 100',
    'wGatherMax100plusStart': 'Gathering limit over 100',
    'wRisk1Closed': 'Risk Level 1 to 4 Closed',
    'wBusOpenRevStart': 'Bus Openings Reversed',
}

# Two rows per order: label + 't' (treated counties) and label + 'c' (controls).
# Built with a real comprehension instead of a list comprehension that was only
# run for its extend() side effect.
outidx = [label + group for label in outorder.values() for group in ('t', 'c')]

# Empty result tables, one column per week offset from -4 to +6.
residByOrder = pd.DataFrame(index=outidx, columns=['Mean' + str(x) for x in range(-4, 7)])
rawByOrder = pd.DataFrame(index=outidx, columns=['Raw' + str(x) for x in range(-4, 7)])


# Event-study tables. For every order in ``outorder`` this fills, for week
# offsets -4..6 around the week the order starts, the mean regression residual
# and the mean raw GrowthFatRate of (a) treated counties and (b) each treated
# county's first listed neighbor on the same date.
weekOffsets = range(-4, 7)


def _sameCountyMean(frame, valuePrefix, x):
    """Return the mean of column ``valuePrefix + str(x)`` in *frame*.

    Only rows whose row shifted by ``x`` still belongs to the same county
    (``countyFIPS == FIPS<x>``) are used, so leads/lags that ran across a
    county boundary are discarded. NaNs are skipped; an empty selection
    yields NaN.
    """
    sameCounty = frame['countyFIPS'] == frame['FIPS' + str(x)]
    return frame.loc[sameCounty, valuePrefix + str(x)].mean(skipna=True)


# First neighbor listed for each county in the pair file (the original lookup
# also took element 0 of the matching rows).
firstNeighbor = (
    pairdat.drop_duplicates(subset='countyFIPS')
    .set_index('countyFIPS')['neighborFIPS']
)

# Raw growth with leads/lags. It does not depend on the order, so it is built
# once; .copy() makes it an independent frame so adding columns does not write
# into a slice of regdatTC.
# NOTE(review): shift() moves by rows, so this presumably relies on regdatTC
# being sorted by county and then date - confirm.
rawG = regdatTC[['countyFIPS', 'Date', 'GrowthFatRate']].copy()
for x in weekOffsets:
    rawG['GrowthFatRate_' + str(x)] = rawG['GrowthFatRate'].shift(-x)
    rawG['FIPS' + str(x)] = rawG['countyFIPS'].shift(-x)

for i in outorder:
    print(i)
    # Rows of treated counties where the order variable is strictly between
    # 0 and 1; these are taken as the order's start rows.
    inStartWeek = (regdatTC[i] > 0) & (regdatTC[i] < 1)
    startIdx = regdatTC.loc[inStartWeek & regdatTC['countyFIPS'].isin(treatFIPS)].index

    # NOTE(review): regformula contains the d-prefixed order terms while the
    # keys of outorder are w-prefixed, so this substitution currently matches
    # nothing and every iteration fits the same model. Behavior is kept as is;
    # confirm whether the matching d-term was meant to be dropped here.
    regresid = re.sub(r'\+ ' + re.escape(i), '', regformula)
    result = smf.ols(regresid, regdatTC).fit()

    # Residuals with county/date attached. The inner merge keeps only rows
    # that have a residual.
    zz = result.resid.to_frame(name='residual')
    zz = pd.merge(regdatTC[['countyFIPS', 'Date']], zz,
                  left_index=True, right_index=True, how='inner')
    #zz = zz.loc[zz['countyFIPS'].isin(basedata['countyFIPS'])]
    startIdx = startIdx.intersection(zz.index)
    for x in weekOffsets:
        zz['res_' + str(x)] = zz['residual'].shift(-x)
        zz['FIPS' + str(x)] = zz['countyFIPS'].shift(-x)

    treat = zz.loc[startIdx, :]
    treatG = rawG.loc[startIdx, :]

    # For each treated start row, pick the neighbor county's row on the same
    # date. Pieces are collected and concatenated once: DataFrame.append was
    # removed in pandas 2.0 and re-copies the frame on every call.
    controlParts = []
    controlGParts = []
    for t in treat.index:
        nFIPS = firstNeighbor.loc[treat.loc[t, 'countyFIPS']]
        eventDate = treat.loc[t, 'Date']
        controlParts.append(zz.loc[(zz['countyFIPS'] == nFIPS) & (zz['Date'] == eventDate), :])
        controlGParts.append(rawG.loc[(rawG['countyFIPS'] == nFIPS) & (rawG['Date'] == eventDate), :])
    # With no treated rows, fall back to empty frames that still carry the
    # expected columns so the means below come out as NaN instead of raising.
    control = pd.concat(controlParts) if controlParts else zz.iloc[0:0]
    controlG = pd.concat(controlGParts) if controlGParts else rawG.iloc[0:0]

    for x in weekOffsets:
        residByOrder.loc[outorder[i] + 't', 'Mean' + str(x)] = _sameCountyMean(treat, 'res_', x)
        residByOrder.loc[outorder[i] + 'c', 'Mean' + str(x)] = _sameCountyMean(control, 'res_', x)
        rawByOrder.loc[outorder[i] + 't', 'Raw' + str(x)] = _sameCountyMean(treatG, 'GrowthFatRate_', x)
        rawByOrder.loc[outorder[i] + 'c', 'Raw' + str(x)] = _sameCountyMean(controlG, 'GrowthFatRate_', x)

print(residByOrder)
print(rawByOrder)

# Week offsets on the x axis; must match the columns of the result tables.
plotWeeks = list(range(-4, 7))


def _saveEventPlot(table, orderLabel, fileSuffix):
    """Plot control vs. treatment rows of *table* for one order and save a PNG.

    table      -- residByOrder or rawByOrder (rows: label + 'c' / label + 't').
    orderLabel -- human-readable order label (a value of ``outorder``).
    fileSuffix -- appended to ``outdir + orderLabel`` to form the file name.
    """
    fig, ax = plt.subplots()
    for rowSuffix, legendLabel in (('c', 'Control'), ('t', 'Treatment')):
        # The tables hold object-dtype cells; cast so matplotlib gets floats.
        values = table.loc[orderLabel + rowSuffix, :].astype(float).values
        ax.plot(plotWeeks, values, label=legendLabel)
    # Fix the tick positions before labelling them; calling set_xticklabels
    # alone relies on whatever ticks the locator happened to pick.
    ax.set_xticks(plotWeeks)
    ax.set_xticklabels([str(x) for x in plotWeeks])
    ax.set_xlabel('weeks')
    ax.set_ylabel('Growth')
    ax.legend(loc='best')
    ax.set_title(orderLabel, size=20)
    fig.savefig(outdir + orderLabel + fileSuffix)
    # Release the figure; otherwise two figures per order stay open.
    plt.close(fig)


for i in outorder:
    _saveEventPlot(residByOrder, outorder[i], 'resid.png')
    _saveEventPlot(rawByOrder, outorder[i], 'raw.png')

 